{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " # Patch Time Series Mixer for Transfer Learning across datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " The `PatchTSMixer` model was proposed in [TSMixer: Lightweight MLP-Mixer Model for Multivariate\n",
    " Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by Vijay Ekambaram, Arindam Jati,\n",
    " Nam Nguyen, Phanwadee Sinthong and Jayant Kalagnanam.\n",
    "\n",
    " `PatchTSMixer` is a time-series foundation modeling approach based on segmenting time series into subseries-level patches and on channel independence.\n",
    "\n",
    " In this notebook, we demonstrate the transfer learning capability of the `PatchTSMixer` model.\n",
    " We first pretrain the model for a forecasting task on a `source` dataset. We then use the\n",
    " pretrained model for zero-shot forecasting on a `target` dataset. The zero-shot forecasting\n",
    " performance is the `test` performance of the model in the `target` domain, without any\n",
    " training on the target domain. Subsequently, we apply linear probing followed by finetuning of\n",
    " the pretrained model on the `train` part of the target data, and evaluate the forecasting\n",
    " performance on the `test` part of the target data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2023-12-11 01:25:50.313015: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2023-12-11 01:25:50.313102: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2023-12-11 01:25:50.313132: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2023-12-11 01:25:51.234452: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "# Standard\n",
    "import os\n",
    "import random\n",
    "\n",
    "# Third Party\n",
    "from transformers import (\n",
    "    EarlyStoppingCallback,\n",
    "    PatchTSMixerConfig,\n",
    "    PatchTSMixerForPrediction,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# First Party\n",
    "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
    "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
    "from tsfm_public.toolkit.util import select_by_index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Set seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and prepare datasets\n",
    "\n",
    "In the next cell, please adjust the following parameters to suit your application:\n",
    "- `PRETRAIN_AGAIN`: Set this to `True` if you want to perform pretraining again. Note that this might take some time depending on GPU availability. Otherwise, the already pretrained model will be used.\n",
    "- `dataset_path`: Path to a local .csv file, or the web address of a .csv file, containing the data of interest. Data is loaded with pandas, so anything supported by\n",
    "`pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n",
    "- `timestamp_column`: Name of the column containing timestamp information. Use `None` if there is no such column.\n",
    "- `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use `[]`.\n",
    "- `forecast_columns`: List of columns to be modeled.\n",
    "- `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to\n",
    "`context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created\n",
    "so that they are contained within a single time series (i.e., a single ID).\n",
    "- `forecast_horizon`: Number of timestamps to forecast into the future.\n",
    "- `train_start_index`, `train_end_index`: The start and end indices in the loaded data which delineate the training data.\n",
    "- `valid_start_index`, `valid_end_index`: The start and end indices in the loaded data which delineate the validation data.\n",
    "- `test_start_index`, `test_end_index`: The start and end indices in the loaded data which delineate the test data.\n",
    "- `patch_length`: The patch length for the `PatchTSMixer` model. It is recommended to choose a value that evenly divides `context_length`.\n",
    "- `num_workers`: Number of dataloader workers in the PyTorch dataloader.\n",
    "- `batch_size`: Batch size.\n",
    "\n",
    "The data is first loaded into a pandas dataframe and split into training, validation, and test parts. Then the pandas dataframes are converted\n",
    "to the appropriate torch dataset needed for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRETRAIN_AGAIN = True\n",
    "# Download ECL data from https://github.com/zhouhaoyi/Informer2020\n",
    "dataset_path = \"~/Downloads/ECL.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "\n",
    "context_length = 512\n",
    "forecast_horizon = 96\n",
    "patch_length = 8\n",
    "num_workers = 16  # Reduce this if you have low number of CPU cores\n",
    "batch_size = 64  # Adjust according to GPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    data = pd.read_csv(\n",
    "        dataset_path,\n",
    "        parse_dates=[timestamp_column],\n",
    "    )\n",
    "    forecast_columns = list(data.columns[1:])\n",
    "\n",
    "    # get split\n",
    "    num_train = int(len(data) * 0.7)\n",
    "    num_test = int(len(data) * 0.2)\n",
    "    num_valid = len(data) - num_train - num_test\n",
    "    border1s = [\n",
    "        0,\n",
    "        num_train - context_length,\n",
    "        len(data) - num_test - context_length,\n",
    "    ]\n",
    "    border2s = [num_train, num_train + num_valid, len(data)]\n",
    "\n",
    "    train_start_index = border1s[0]  # 0, i.e. the beginning of the dataset\n",
    "    train_end_index = border2s[0]\n",
    "\n",
    "    # we shift the start of the evaluation period back by context length so that\n",
    "    # the first evaluation timestamp is immediately following the training data\n",
    "    valid_start_index = border1s[1]\n",
    "    valid_end_index = border2s[1]\n",
    "\n",
    "    test_start_index = border1s[2]\n",
    "    test_end_index = border2s[2]\n",
    "\n",
    "    train_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=train_start_index,\n",
    "        end_index=train_end_index,\n",
    "    )\n",
    "    valid_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=valid_start_index,\n",
    "        end_index=valid_end_index,\n",
    "    )\n",
    "    test_data = select_by_index(\n",
    "        data,\n",
    "        id_columns=id_columns,\n",
    "        start_index=test_start_index,\n",
    "        end_index=test_end_index,\n",
    "    )\n",
    "\n",
    "    tsp = TimeSeriesPreprocessor(\n",
    "        timestamp_column=timestamp_column,\n",
    "        id_columns=id_columns,\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        scaling=True,\n",
    "    )\n",
    "    tsp.train(train_data)"
   ]
  },
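  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split logic in the previous cell can be studied in isolation. The following is a minimal sketch, not part of the pipeline: `split_borders` is a hypothetical helper introduced here for illustration, the row count of 10,000 is made up, and each index pair is assumed to be applied as a half-open `[start, end)` slice.\n",
    "\n",
    "```python\n",
    "def split_borders(n_rows, context_length, train_frac=0.7, test_frac=0.2):\n",
    "    # Hypothetical helper: (start, end) index pairs for train, valid, test.\n",
    "    num_train = int(n_rows * train_frac)\n",
    "    num_test = int(n_rows * test_frac)\n",
    "    num_valid = n_rows - num_train - num_test\n",
    "    # Evaluation segments start context_length rows early.\n",
    "    starts = [0, num_train - context_length, n_rows - num_test - context_length]\n",
    "    ends = [num_train, num_train + num_valid, n_rows]\n",
    "    return list(zip(starts, ends))\n",
    "\n",
    "\n",
    "# Made-up dataset size, with the context length used in this notebook.\n",
    "train, valid, test = split_borders(n_rows=10_000, context_length=512)\n",
    "print(train, valid, test)\n",
    "```\n",
    "\n",
    "The validation and test segments start `context_length` rows early so that their first forecast window can draw its history from the rows that precede the segment."
   ]
  },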
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    train_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(train_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=timestamp_column,\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )\n",
    "    valid_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(valid_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=timestamp_column,\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )\n",
    "    test_dataset = ForecastDFDataset(\n",
    "        tsp.preprocess(test_data),\n",
    "        id_columns=id_columns,\n",
    "        timestamp_column=timestamp_column,\n",
    "        input_columns=forecast_columns,\n",
    "        output_columns=forecast_columns,\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "    )"
   ]
  },
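  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each dataset built above yields one sample per window of `context_length` input steps followed by `forecast_horizon` target steps. A rough way to reason about the resulting dataset sizes is sketched below. It assumes a single series (no ID columns) and a window stride of one, which is an assumption about `ForecastDFDataset` rather than something verified here; `num_windows` is a hypothetical helper and the segment length is made up.\n",
    "\n",
    "```python\n",
    "def num_windows(n_rows, context_length, prediction_length):\n",
    "    # Each sample needs context_length + prediction_length consecutive rows.\n",
    "    return max(n_rows - context_length - prediction_length + 1, 0)\n",
    "\n",
    "\n",
    "# A made-up segment of 2,512 rows with this notebook's settings.\n",
    "print(num_windows(2512, context_length=512, prediction_length=96))\n",
    "```\n",
    "\n",
    "A segment shorter than `context_length + forecast_horizon` rows produces no samples under this assumption, which is worth checking before training on small datasets."
   ]
  },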
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Configure the PatchTSMixer model\n",
    "\n",
    " The settings below control the different components in the PatchTSMixer model.\n",
    "  - `num_input_channels`: The number of input channels (or dimensions) in the time series data. This is\n",
    "    automatically set to the number of forecast columns.\n",
    "  - `context_length`: As described above, the amount of historical data used as input to the model.\n",
    "  - `prediction_length`: This is the same as the forecast horizon described above.\n",
    "  - `patch_length`: The length of the patches extracted from the context window (of length `context_length`).\n",
    "  - `patch_stride`: The stride used when extracting patches from the context window.\n",
    "  - `d_model`: Hidden feature dimension of the model.\n",
    "  - `num_layers`: The number of model layers.\n",
    "  - `expansion_factor`: Expansion factor used inside the MLP blocks.\n",
    "  - `dropout`: Dropout probability for all fully connected layers in the encoder.\n",
    "  - `head_dropout`: Dropout probability used in the head of the model.\n",
    "  - `mode`: PatchTSMixer operating mode, either \"common_channel\" or \"mix_channel\". \"common_channel\" works in channel-independent mode. For pretraining, use \"common_channel\".\n",
    "  - `scaling`: Per-window standard scaling. Recommended value: \"std\".\n",
    "\n",
    " We recommend that you only adjust the values in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    config = PatchTSMixerConfig(\n",
    "        context_length=context_length,\n",
    "        prediction_length=forecast_horizon,\n",
    "        patch_length=patch_length,\n",
    "        num_input_channels=len(forecast_columns),\n",
    "        patch_stride=patch_length,\n",
    "        d_model=16,\n",
    "        num_layers=8,\n",
    "        expansion_factor=2,\n",
    "        dropout=0.2,\n",
    "        head_dropout=0.2,\n",
    "        mode=\"common_channel\",\n",
    "        scaling=\"std\",\n",
    "    )\n",
    "    model = PatchTSMixerForPrediction(config)"
   ]
  },
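  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of how `context_length`, `patch_length`, and `patch_stride` interact in the configuration above, the sketch below counts the patches extracted from one context window. It assumes the usual sliding-window formula with no padding; `num_patches` is a hypothetical helper, not a library function.\n",
    "\n",
    "```python\n",
    "def num_patches(context_length, patch_length, patch_stride):\n",
    "    # Sliding-window count without padding.\n",
    "    if patch_length > context_length:\n",
    "        raise ValueError('patch_length must not exceed context_length')\n",
    "    return (context_length - patch_length) // patch_stride + 1\n",
    "\n",
    "\n",
    "# Non-overlapping patches (stride equal to patch length), as configured above.\n",
    "print(num_patches(context_length=512, patch_length=8, patch_stride=8))  # 64\n",
    "```\n",
    "\n",
    "With a stride equal to the patch length the patches tile the window exactly, which is why a `patch_length` that divides `context_length` is recommended; a smaller stride would produce overlapping patches and a longer patch sequence."
   ]
  },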
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Train model\n",
    "\n",
    " Trains the PatchTSMixer model based on the direct forecasting strategy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='2450' max='7000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [2450/7000 21:35 < 40:08, 1.89 it/s, Epoch 35/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.247100</td>\n",
       "      <td>0.141067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.168600</td>\n",
       "      <td>0.127757</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.156500</td>\n",
       "      <td>0.122327</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.150300</td>\n",
       "      <td>0.118918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.146000</td>\n",
       "      <td>0.116496</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.143100</td>\n",
       "      <td>0.114968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.140800</td>\n",
       "      <td>0.113678</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.139200</td>\n",
       "      <td>0.113057</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.137900</td>\n",
       "      <td>0.112405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.136900</td>\n",
       "      <td>0.112225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.136100</td>\n",
       "      <td>0.112087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.135400</td>\n",
       "      <td>0.112330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.134700</td>\n",
       "      <td>0.111778</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.134100</td>\n",
       "      <td>0.111702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.133700</td>\n",
       "      <td>0.110964</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.133100</td>\n",
       "      <td>0.111164</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.132800</td>\n",
       "      <td>0.111063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.132400</td>\n",
       "      <td>0.111088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.132100</td>\n",
       "      <td>0.110905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.131800</td>\n",
       "      <td>0.110844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.131300</td>\n",
       "      <td>0.110831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.131100</td>\n",
       "      <td>0.110278</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.130700</td>\n",
       "      <td>0.110591</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.130600</td>\n",
       "      <td>0.110319</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.130300</td>\n",
       "      <td>0.109900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.130000</td>\n",
       "      <td>0.109982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.129900</td>\n",
       "      <td>0.109975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.129600</td>\n",
       "      <td>0.110128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.129300</td>\n",
       "      <td>0.109995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.129100</td>\n",
       "      <td>0.109868</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.129000</td>\n",
       "      <td>0.109928</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.128700</td>\n",
       "      <td>0.109823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.128500</td>\n",
       "      <td>0.109863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.128400</td>\n",
       "      <td>0.109794</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.128100</td>\n",
       "      <td>0.109945</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    }
   ],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    training_args = TrainingArguments(\n",
    "        output_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/output/\",\n",
    "        overwrite_output_dir=True,\n",
    "        learning_rate=0.001,\n",
    "        num_train_epochs=100,  # For a quick test of this notebook, set it to 1\n",
    "        do_eval=True,\n",
    "        evaluation_strategy=\"epoch\",\n",
    "        per_device_train_batch_size=batch_size,\n",
    "        per_device_eval_batch_size=batch_size,\n",
    "        dataloader_num_workers=num_workers,\n",
    "        report_to=\"tensorboard\",\n",
    "        save_strategy=\"epoch\",\n",
    "        logging_strategy=\"epoch\",\n",
    "        save_total_limit=3,\n",
    "        logging_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/logs/\",  # Make sure to specify a logging directory\n",
    "        load_best_model_at_end=True,  # Load the best model when training ends\n",
    "        metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "        greater_is_better=False,  # For loss\n",
    "        label_names=[\"future_values\"],\n",
    "        # max_steps=20,\n",
    "    )\n",
    "\n",
    "    # Create the early stopping callback\n",
    "    early_stopping_callback = EarlyStoppingCallback(\n",
    "        early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "        early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement\n",
    "    )\n",
    "\n",
    "    # define trainer\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=valid_dataset,\n",
    "        callbacks=[early_stopping_callback],\n",
    "    )\n",
    "\n",
    "    # pretrain\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Evaluate the model on the test set of the `source` domain\n",
    " Note that this is not the metric of interest in this task; what matters is the performance on the `target` domain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='21' max='21' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [21/21 00:03]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test result:\n",
      "{'eval_loss': 0.12884521484375, 'eval_runtime': 5.7532, 'eval_samples_per_second': 897.763, 'eval_steps_per_second': 3.65, 'epoch': 35.0}\n"
     ]
    }
   ],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    results = trainer.evaluate(test_dataset)\n",
    "    print(\"Test result:\")\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ## Save model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "if PRETRAIN_AGAIN:\n",
    "    save_dir = \"patchtsmixer/electricity/model/pretrain/\"\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    trainer.save_model(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transfer learning on `ETTh2` data\n",
    "All evaluations are on the `test` part of the `ETTh2` data.\n",
    "\n",
    "Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance.  \n",
    "Step 2: Evaluate after linear probing.  \n",
    "Step 3: Evaluate after full finetuning.  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load ETTh2 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ETTh2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading target dataset: ETTh2\n"
     ]
    }
   ],
   "source": [
    "print(f\"Loading target dataset: {dataset}\")\n",
    "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n",
    "train_start_index = None  # None indicates beginning of dataset\n",
    "train_end_index = 12 * 30 * 24\n",
    "\n",
    "# Shift the start of the evaluation period back by the context length so that\n",
    "# the first evaluation timestamp immediately follows the training data\n",
    "valid_start_index = 12 * 30 * 24 - context_length\n",
    "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n",
    "\n",
    "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n",
    "test_end_index = 12 * 30 * 24 + 8 * 30 * 24"
   ]
  },
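  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split arithmetic above can be checked with a small standalone sketch. The values of `context_length` and `forecast_horizon` below are placeholders chosen for illustration (the notebook defines its own earlier), and `num_windows` is a hypothetical helper that assumes one sample per sliding window with stride 1.\n",
    "\n",
    "```python\n",
    "# Placeholder values for illustration only; the notebook defines its own.\n",
    "context_length = 512\n",
    "forecast_horizon = 96\n",
    "\n",
    "hours_per_month = 30 * 24  # the split code treats a month as 30 days of hourly rows\n",
    "train_end_index = 12 * hours_per_month\n",
    "valid_start_index = train_end_index - context_length\n",
    "valid_end_index = train_end_index + 4 * hours_per_month\n",
    "test_start_index = valid_end_index - context_length\n",
    "test_end_index = train_end_index + 8 * hours_per_month\n",
    "\n",
    "# The first window of a split uses its first context_length rows as input,\n",
    "# so its first forecast target is the row right after the previous split.\n",
    "first_valid_target = valid_start_index + context_length\n",
    "first_test_target = test_start_index + context_length\n",
    "\n",
    "\n",
    "def num_windows(start, end, context, horizon):\n",
    "    # Hypothetical helper: sliding windows with stride 1 over rows [start, end).\n",
    "    return (end - start) - context - horizon + 1\n",
    "\n",
    "\n",
    "print(first_valid_target == train_end_index)\n",
    "print(first_test_target == valid_end_index)\n",
    "print(num_windows(test_start_index, test_end_index, context_length, forecast_horizon))\n",
    "```\n",
    "\n",
    "Shifting the start index back by `context_length` means no forecast target in the validation or test split falls inside the preceding split, while every row after the boundary can still be forecast with a full context window."
   ]
  },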
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TimeSeriesPreprocessor {\n",
       "  \"context_length\": 64,\n",
       "  \"feature_extractor_type\": \"TimeSeriesPreprocessor\",\n",
       "  \"id_columns\": [],\n",
       "  \"input_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"output_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"prediction_length\": null,\n",
       "  \"processor_class\": \"TimeSeriesPreprocessor\",\n",
       "  \"scaler_dict\": {\n",
       "    \"0\": {\n",
       "      \"copy\": true,\n",
       "      \"feature_names_in_\": [\n",
       "        \"HUFL\",\n",
       "        \"HULL\",\n",
       "        \"MUFL\",\n",
       "        \"MULL\",\n",
       "        \"LUFL\",\n",
       "        \"LULL\",\n",
       "        \"OT\"\n",
       "      ],\n",
       "      \"mean_\": [\n",
       "        41.53683496078959,\n",
       "        12.273452896210882,\n",
       "        46.60977329964991,\n",
       "        10.526153112865156,\n",
       "        1.1869920139097505,\n",
       "        -2.373217913729173,\n",
       "        26.872023494265697\n",
       "      ],\n",
       "      \"n_features_in_\": 7,\n",
       "      \"n_samples_seen_\": 8640,\n",
       "      \"scale_\": [\n",
       "        10.448841072588488,\n",
       "        4.587112566531959,\n",
       "        16.858190332598408,\n",
       "        3.018605566682919,\n",
       "        4.641011217319063,\n",
       "        8.460910779279644,\n",
       "        11.584718923414682\n",
       "      ],\n",
       "      \"var_\": [\n",
       "        109.17827976021215,\n",
       "        21.04160169803542,\n",
       "        284.19858129011436,\n",
       "        9.111979567209104,\n",
       "        21.538985119281367,\n",
       "        71.58701121493046,\n",
       "        134.20571253452223\n",
       "      ],\n",
       "      \"with_mean\": true,\n",
       "      \"with_std\": true\n",
       "    }\n",
       "  },\n",
       "  \"scaling\": true,\n",
       "  \"time_series_task\": \"forecasting\",\n",
       "  \"timestamp_column\": \"date\"\n",
       "}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(\n",
    "    dataset_path,\n",
    "    parse_dates=[timestamp_column],\n",
    ")\n",
    "\n",
    "train_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=train_start_index,\n",
    "    end_index=train_end_index,\n",
    ")\n",
    "valid_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=valid_start_index,\n",
    "    end_index=valid_end_index,\n",
    ")\n",
    "test_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=test_start_index,\n",
    "    end_index=test_end_index,\n",
    ")\n",
    "\n",
    "tsp = TimeSeriesPreprocessor(\n",
    "    timestamp_column=timestamp_column,\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    scaling=True,\n",
    ")\n",
    "tsp.train(train_data)"
   ]
  },
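  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `scaler_dict` printed above stores per-column `mean_`, `var_` and `scale_` values. As a rough sketch of what these fields mean, assuming `scaling=True` applies standard (z-score) scaling as the `with_mean` and `with_std` flags suggest, the snippet below re-derives `scale_` from `var_` and standardizes one row. The helper functions and the `sample` row are hypothetical, not part of the preprocessor API.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Values copied from the scaler output above for the first two columns (HUFL, HULL).\n",
    "mean_ = [41.53683496078959, 12.273452896210882]\n",
    "var_ = [109.17827976021215, 21.04160169803542]\n",
    "scale_ = [10.448841072588488, 4.587112566531959]\n",
    "\n",
    "# scale_ is the standard deviation, i.e. the square root of var_.\n",
    "for v, s in zip(var_, scale_):\n",
    "    assert math.isclose(math.sqrt(v), s, rel_tol=1e-9)\n",
    "\n",
    "\n",
    "def standardize(row, mean, scale):\n",
    "    # z = (x - mean) / scale, applied column by column\n",
    "    return [(x - m) / s for x, m, s in zip(row, mean, scale)]\n",
    "\n",
    "\n",
    "def unstandardize(row, mean, scale):\n",
    "    # Inverse transform: x = z * scale + mean\n",
    "    return [z * s + m for z, m, s in zip(row, mean, scale)]\n",
    "\n",
    "\n",
    "sample = [52.0, 10.0]  # hypothetical raw readings, not taken from the dataset\n",
    "z = standardize(sample, mean_, scale_)\n",
    "print(z)\n",
    "print(unstandardize(z, mean_, scale_))\n",
    "```\n",
    "\n",
    "The statistics are fitted on the `train` split only (`n_samples_seen_` is 8640, which equals `12 * 30 * 24`), so the validation and test splits are scaled with training statistics."
   ]
  },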
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(train_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "valid_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(valid_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "test_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(test_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-shot forecasting on `ETTh2`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model\n",
      "Done\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading pretrained model\")\n",
    "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
    "    \"patchtsmixer/electricity/model/pretrain/\"\n",
    ")\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Doing zero-shot forecasting on target data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='22' max='11' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [11/11 02:52]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data zero-shot forecasting result:\n",
      "{'eval_loss': 0.3038313388824463, 'eval_runtime': 1.8364, 'eval_samples_per_second': 1516.562, 'eval_steps_per_second': 5.99}\n"
     ]
    }
   ],
   "source": [
    "finetune_forecast_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtsmixer/transfer/finetune/output/\",\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=0.0001,\n",
    "    num_train_epochs=100,\n",
    "    do_eval=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    dataloader_num_workers=num_workers,\n",
    "    report_to=\"tensorboard\",\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    logging_dir=\"./checkpoint/patchtsmixer/transfer/finetune/logs/\",  # Make sure to specify a logging directory\n",
    "    load_best_model_at_end=True,  # Load the best model when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "    greater_is_better=False,  # For loss\n",
    ")\n",
    "\n",
    "# Create a new early stopping callback with a shorter patience and a larger threshold\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "\n",
    "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data zero-shot forecasting result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Target data `ETTh2` linear probing\n",
    "We can run a quick linear probe on the `train` part of the target data to see whether the `test` performance improves. In linear probing, the pretrained backbone is frozen and only the forecasting head is trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Linear probing on the target data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='416' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 416/3200 01:01 < 06:53, 6.73 it/s, Epoch 13/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.447000</td>\n",
       "      <td>0.216436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.438600</td>\n",
       "      <td>0.215667</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.429400</td>\n",
       "      <td>0.215104</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.422500</td>\n",
       "      <td>0.213820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.418500</td>\n",
       "      <td>0.213585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.415000</td>\n",
       "      <td>0.213016</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.412000</td>\n",
       "      <td>0.213067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.412400</td>\n",
       "      <td>0.211993</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.405900</td>\n",
       "      <td>0.212460</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.405300</td>\n",
       "      <td>0.211772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.406200</td>\n",
       "      <td>0.212154</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.400600</td>\n",
       "      <td>0.212082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.405300</td>\n",
       "      <td>0.211458</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='11' max='11' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [11/11 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data head/linear probing result:\n",
      "{'eval_loss': 0.27119266986846924, 'eval_runtime': 1.7621, 'eval_samples_per_second': 1580.478, 'eval_steps_per_second': 6.242, 'epoch': 13.0}\n"
     ]
    }
   ],
   "source": [
    "# Freeze the backbone of the model\n",
    "for param in finetune_forecast_trainer.model.model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "print(\"\\n\\nLinear probing on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data head/linear probing result:\")\n",
    "print(result)"
   ]
  },
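  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The run above stops long before the configured 100 epochs because of the early stopping callback (`early_stopping_patience=5`, `early_stopping_threshold=0.001`). The sketch below is a simplified re-implementation of a patience/threshold rule, written for intuition only and not taken from the `transformers` source. It assumes that the best validation loss is updated on any decrease, while the patience counter is reset only when the decrease exceeds the threshold. `early_stop_epoch` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def early_stop_epoch(val_losses, patience, threshold):\n",
    "    # Return the 1-based epoch at which training would stop, or None.\n",
    "    best = None\n",
    "    bad_epochs = 0\n",
    "    for epoch, loss in enumerate(val_losses, start=1):\n",
    "        if best is None or (loss < best and best - loss > threshold):\n",
    "            bad_epochs = 0  # improvement larger than the threshold\n",
    "        else:\n",
    "            bad_epochs += 1  # no improvement, or improvement too small\n",
    "        # Assumption: the best loss is tracked separately and updated on any decrease.\n",
    "        if best is None or loss < best:\n",
    "            best = loss\n",
    "        if bad_epochs >= patience:\n",
    "            return epoch\n",
    "    return None\n",
    "\n",
    "\n",
    "# Validation losses copied from the linear probing table above (epochs 1 to 13).\n",
    "linear_probe_val = [\n",
    "    0.216436, 0.215667, 0.215104, 0.213820, 0.213585, 0.213016, 0.213067,\n",
    "    0.211993, 0.212460, 0.211772, 0.212154, 0.212082, 0.211458,\n",
    "]\n",
    "print(early_stop_epoch(linear_probe_val, patience=5, threshold=0.001))\n",
    "```\n",
    "\n",
    "With the losses from the table, this sketch stops at epoch 13, which is consistent with the progress bar above ending at `Epoch 13/100`. Note that the validation loss is still decreasing slowly at that point; the decreases are simply smaller than the threshold."
   ]
  },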
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['patchtsmixer/electricity/model/transfer/ETTh2/preprocessor/preprocessor_config.json']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/linear_probe/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)\n",
    "\n",
    "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/preprocessor/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "tsp.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Target data `ETTh2` full finetune\n",
    "\n",
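    "Alternatively, we can finetune the full model (instead of training only the last linear layer, as in the linear probing above) on the `train` part of the target data to see whether the `test` performance improves. The pretrained model is reloaded first, so finetuning starts from the pretrained weights rather than from the linear-probed model."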
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Finetuning on the target data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='288' max='3200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 288/3200 00:44 < 07:34, 6.40 it/s, Epoch 9/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.432900</td>\n",
       "      <td>0.215200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.416700</td>\n",
       "      <td>0.210919</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.401400</td>\n",
       "      <td>0.209932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.392900</td>\n",
       "      <td>0.208808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.388100</td>\n",
       "      <td>0.209692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.375900</td>\n",
       "      <td>0.209546</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.370000</td>\n",
       "      <td>0.210207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.367000</td>\n",
       "      <td>0.211601</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.359400</td>\n",
       "      <td>0.211405</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluating\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='11' max='11' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [11/11 00:00]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Target data full finetune result:\n",
      "{'eval_loss': 0.2734043300151825, 'eval_runtime': 1.5853, 'eval_samples_per_second': 1756.725, 'eval_steps_per_second': 6.939, 'epoch': 9.0}\n"
     ]
    }
   ],
   "source": [
    "# Reload the model\n",
    "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
    "    \"patchtsmixer/electricity/model/pretrain/\"\n",
    ")\n",
    "finetune_forecast_trainer = Trainer(\n",
    "    model=finetune_forecast_model,\n",
    "    args=finetune_forecast_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "print(\"\\n\\nFinetuning on the target data\")\n",
    "finetune_forecast_trainer.train()\n",
    "print(\"Evaluating\")\n",
    "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
    "print(\"Target data full finetune result:\")\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "finetune_forecast_trainer.save_model(save_dir)"
   ]
  },
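  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To compare the three settings at a glance, the test losses printed above can be put side by side. The sketch below only does arithmetic on those printed numbers; `relative_reduction` is a hypothetical helper, and the figures describe this single run rather than a general result.\n",
    "\n",
    "```python\n",
    "# Test losses copied from the three evaluation outputs above.\n",
    "test_loss = {\n",
    "    'zero-shot': 0.3038313388824463,\n",
    "    'linear probe': 0.27119266986846924,\n",
    "    'full finetune': 0.2734043300151825,\n",
    "}\n",
    "\n",
    "\n",
    "def relative_reduction(baseline, loss):\n",
    "    # Percentage reduction in loss relative to the baseline.\n",
    "    return 100.0 * (baseline - loss) / baseline\n",
    "\n",
    "\n",
    "baseline = test_loss['zero-shot']\n",
    "for name, loss in test_loss.items():\n",
    "    reduction = relative_reduction(baseline, loss)\n",
    "    print(f'{name:>13}: loss={loss:.4f}  reduction vs zero-shot={reduction:.1f}%')\n",
    "```\n",
    "\n",
    "In this run, both linear probing and full finetuning lower the test loss by roughly 10 to 11 percent relative to zero-shot, and linear probing ends slightly lower than full finetuning."
   ]
  },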
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
